{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.optim as optim\n",
    "import torch.nn.functional as F\n",
    "from torch.autograd.function import Function\n",
    "import os\n",
    "import pickle\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import math\n",
    "from itertools import chain\n",
    "from sklearn.model_selection import train_test_split\n",
    "import itertools\n",
    "import random\n",
    "import tqdm\n",
    "from IPython.display import display, HTML\n",
    "import matplotlib.pyplot as plt\n",
    "from scipy import signal\n",
    "from scipy.fft import fftshift\n",
    "\n",
    "from einops import rearrange, repeat\n",
    "from einops.layers.torch import Rearrange\n",
    "from torch import nn, einsum\n",
    "import math\n",
    "import logging\n",
    "from functools import partial\n",
    "from collections import OrderedDict\n",
    "from sklearn.metrics import classification_report\n",
    "from torchsummary import summary\n",
    "import gpustat"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 53,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/bin/sh: 1: gpustat: not found\r\n"
     ]
    }
   ],
   "source": [
    "# The `gpustat` executable is not on the shell PATH here (`!gpustat` printed\n",
    "# 'gpustat: not found'), so query the GPUs through the Python package that the\n",
    "# import cell already loads.\n",
    "gpustat.print_gpustat()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 124,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of consecutive segments grouped into one sequence sample; the data and\n",
    "# label arrays in the following cells are reshaped with this value.\n",
    "seq_len = 3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 126,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(790121, 32, 30)\n",
      "(790121, 32, 30)\n",
      "(263373, 3, 2, 32, 30)\n"
     ]
    }
   ],
   "source": [
    "dat1 = np.load('/project/hikaku_db/data/sleep_SHHS/stages_sig/C4_spec_30_np/spec_1_800_30.npy')\n",
    "dat2 = np.load('/project/hikaku_db/data/sleep_SHHS/stages_sig/C3_spec_30_np/spec_c3_1_800_30.npy')\n",
    "print(dat1.shape)\n",
    "print(dat2.shape)\n",
    "# Keep only as many segments as fill whole sequences of `seq_len`. The cut-off is\n",
    "# derived from the data (it was a hard-coded 790119, valid only for seq_len = 3).\n",
    "n_usable = (min(len(dat1), len(dat2)) // seq_len) * seq_len\n",
    "dat1 = dat1[:n_usable, :, :]\n",
    "dat2 = dat2[:n_usable, :, :]\n",
    "# Add a singleton axis per recording and join the two along it (axis 2).\n",
    "dat_eeg = np.concatenate((dat1.reshape(-1,seq_len,1,32,30), dat2.reshape(-1,seq_len,1,32,30)), axis=2)\n",
    "print(dat_eeg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 127,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "        0\n",
      "2  329696\n",
      "0  221935\n",
      "4  109008\n",
      "3  100831\n",
      "1   28651\n",
      "(263373, 3)\n"
     ]
    }
   ],
   "source": [
    "index = pd.read_csv(\"/project/hikaku_db/data/sleep_SHHS/stages_sig/ann_delrecords_5class.csv\", header=None)\n",
    "# `pd.value_counts` is deprecated; `pd.Series.value_counts` yields the same table.\n",
    "print(index[0 : 790121].apply(pd.Series.value_counts))\n",
    "# Trim the labels to exactly the segments kept in `dat_eeg`, so data and labels\n",
    "# cannot fall out of step when `seq_len` or the input files change.\n",
    "n_used = len(dat_eeg) * seq_len\n",
    "label = index[:n_used].values\n",
    "label = label.reshape(-1,seq_len)\n",
    "assert len(label) == len(dat_eeg), 'label rows do not match the number of EEG sequences'\n",
    "print(label.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 128,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(263373, 3, 2, 20, 30)\n"
     ]
    }
   ],
   "source": [
    "# Rows 0-15 of axis 3 are kept unchanged.\n",
    "fixdata = dat_eeg[:, :, :, :16, :]\n",
    "# Rows 16-31 of axis 3 are averaged in four consecutive groups of four rows each.\n",
    "mean_p1, mean_p2, mean_p3, mean_p4 = [\n",
    "    dat_eeg[:, :, :, start:start + 4, :].mean(axis=3) for start in range(16, 32, 4)\n",
    "]\n",
    "num_data = dat_eeg.shape[0]\n",
    "ch = 2\n",
    "# Re-insert the averaged axis as length 1 and append the four pooled rows to the\n",
    "# 16 untouched ones, giving 20 rows on axis 3.\n",
    "inputeeg = np.concatenate(\n",
    "    [fixdata, *(band.reshape(-1, seq_len, ch, 1, 30) for band in (mean_p1, mean_p2, mean_p3, mean_p4))],\n",
    "    axis=3,\n",
    ")\n",
    "print(inputeeg.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 131,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Mydatasets(torch.utils.data.Dataset):\n",
    "    \"\"\"Map-style dataset that pairs each sample with the label stored at the same index.\n",
    "\n",
    "    Args:\n",
    "        data1: indexable collection of samples; each item is returned as a float32 tensor.\n",
    "        label: indexable collection of labels, read with the same index as ``data1``.\n",
    "        transform: optional callable applied to the sample tensor only.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, data1, label ,transform = None):\n",
    "        self.data1, self.label = data1, label\n",
    "        self.transform = transform\n",
    "        self.datanum = len(data1)\n",
    "\n",
    "    def __len__(self):\n",
    "        return self.datanum\n",
    "\n",
    "    def __getitem__(self, idx):\n",
    "        sample = torch.tensor(self.data1[idx], dtype=torch.float32)\n",
    "        target = torch.tensor(self.label[idx])\n",
    "        if not self.transform:\n",
    "            return sample, target\n",
    "        return self.transform(sample), target"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 132,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "train eeg data: 237035\n",
      "test eeg data: 26338\n"
     ]
    }
   ],
   "source": [
    "# Hold out 10% of the sequences; the fixed random_state keeps the split repeatable.\n",
    "train_eeg, test_eeg, train_label, test_label = train_test_split(inputeeg, label, test_size = 0.1,random_state = 66)\n",
    "print('train eeg data:',len(train_eeg))\n",
    "print('test eeg data:',len(test_eeg))\n",
    "\n",
    "# `Mydatasets.__init__` names its first parameter `data1`; passing `data=` raised a TypeError.\n",
    "train_data_set = Mydatasets(data1 = train_eeg, label = train_label)\n",
    "test_data_set = Mydatasets(data1 = test_eeg, label = test_label)\n",
    "\n",
    "train_dataloader = torch.utils.data.DataLoader(train_data_set, batch_size = 16, shuffle=True)\n",
    "test_dataloader = torch.utils.data.DataLoader(test_data_set, batch_size = 16, shuffle=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 148,
   "metadata": {},
   "outputs": [],
   "source": [
    "##Seq_VTransformer model\n",
    "\n",
    "class PreNorm(nn.Module):\n",
    "    \"\"\"Normalise the input with ``nn.LayerNorm(dim)`` and pass the result to ``fn``.\"\"\"\n",
    "\n",
    "    def __init__(self, dim, fn):\n",
    "        super().__init__()\n",
    "        self.norm = nn.LayerNorm(dim)\n",
    "        self.fn = fn\n",
    "\n",
    "    def forward(self, x, **kwargs):\n",
    "        # Extra keyword arguments are handed to the wrapped module untouched.\n",
    "        normalized = self.norm(x)\n",
    "        return self.fn(normalized, **kwargs)\n",
    "\n",
    "class FeedForward(nn.Module):\n",
    "    \"\"\"Position-wise MLP: Linear(dim -> mlp_dim), GELU, Dropout, Linear(mlp_dim -> dim), Dropout.\"\"\"\n",
    "\n",
    "    def __init__(self, dim, mlp_dim, dropout):\n",
    "        super().__init__()\n",
    "        expand = nn.Linear(dim, mlp_dim)      # widen to the hidden size\n",
    "        contract = nn.Linear(mlp_dim, dim)    # project back to the input size\n",
    "        self.net = nn.Sequential(\n",
    "            expand,\n",
    "            nn.GELU(),\n",
    "            nn.Dropout(dropout),\n",
    "            contract,\n",
    "            nn.Dropout(dropout),\n",
    "        )\n",
    "\n",
    "    def forward(self, x):\n",
    "        return self.net(x)\n",
    "\n",
    "class Attention(nn.Module):\n",
    "    \"\"\"Multi-head scaled dot-product self-attention over a (batch, tokens, dim) input.\n",
    "\n",
    "    Queries, keys and values come from one bias-free linear layer of width\n",
    "    ``3 * heads * dim_head``. The concatenated heads are projected back to ``dim``\n",
    "    (followed by dropout) unless there is a single head whose width already equals ``dim``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, dim, dropout, heads = 4, dim_head = 8):\n",
    "        super().__init__()\n",
    "        inner_dim = heads * dim_head\n",
    "\n",
    "        self.heads = heads\n",
    "        self.scale = dim_head ** -0.5\n",
    "\n",
    "        self.attend = nn.Softmax(dim = -1)\n",
    "        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)\n",
    "\n",
    "        if heads == 1 and dim_head == dim:\n",
    "            # Head output already has the model width, so no output projection is needed.\n",
    "            self.to_out = nn.Identity()\n",
    "        else:\n",
    "            self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))\n",
    "\n",
    "    def forward(self, x):\n",
    "        # Unpacking three values makes a non-3-D input fail right here.\n",
    "        _batch, _tokens, _width = x.shape\n",
    "        queries, keys, values = (\n",
    "            rearrange(part, 'b n (h d) -> b h n d', h = self.heads)\n",
    "            for part in self.to_qkv(x).chunk(3, dim = -1)\n",
    "        )\n",
    "\n",
    "        # Scaled similarity of every query token with every key token, per head.\n",
    "        scores = einsum('b h i d, b h j d -> b h i j', queries, keys) * self.scale\n",
    "        weights = self.attend(scores)\n",
    "\n",
    "        # Weighted sum of the values, then merge the heads back into one feature axis.\n",
    "        mixed = einsum('b h i j, b h j d -> b h i d', weights, values)\n",
    "        merged = rearrange(mixed, 'b h n d -> b n (h d)')\n",
    "        return self.to_out(merged)\n",
    "\n",
    "class Transformer(nn.Module):\n",
    "    \"\"\"Stack of ``depth`` pre-norm blocks: self-attention then feed-forward, each with a residual connection.\"\"\"\n",
    "\n",
    "    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):\n",
    "        super().__init__()\n",
    "        self.layers = nn.ModuleList([\n",
    "            nn.ModuleList([\n",
    "                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),\n",
    "                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)),\n",
    "            ])\n",
    "            for _ in range(depth)\n",
    "        ])\n",
    "\n",
    "    def forward(self, x):\n",
    "        for attention_block, feedforward_block in self.layers:\n",
    "            x = x + attention_block(x)      # residual around attention\n",
    "            x = x + feedforward_block(x)    # residual around the MLP\n",
    "        return x\n",
    "\n",
    "class ViT(nn.Module):\n",
    "    \"\"\"Encode every segment of a sequence independently with a patch transformer.\n",
    "\n",
    "    ``forward`` takes a 5-D tensor indexed as (batch, segment, channel, rows, columns).\n",
    "    Each segment is cut into ``fre_size`` x ``time_size`` patches, embedded to ``dim``,\n",
    "    prefixed with a learned class token and passed through the transformer. The pooled\n",
    "    vector of each segment is returned, giving a (batch, segment, dim) tensor.\n",
    "\n",
    "    Args:\n",
    "        image_size: must be 30 (asserted).\n",
    "        time_size: patch extent along the last input axis.\n",
    "        fre_size: patch extent along the second-to-last input axis.\n",
    "        dim: embedding width.\n",
    "        channels: number of input channels per segment.\n",
    "        dim_head, depth, heads, mlp_dim, dropout: forwarded to ``Transformer``.\n",
    "        pool: 'cls' to return the class-token output, 'mean' to average all tokens.\n",
    "        emb_dropout: dropout applied to the embedded tokens.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,image_size, time_size, fre_size, dim, channels , dim_head,depth, heads, mlp_dim, dropout, pool = 'cls', emb_dropout = 0.):\n",
    "        super().__init__()\n",
    "        assert image_size == 30  ##Time dimensions must equal to 30s\n",
    "        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'\n",
    "        num_patches = 150       #30*5(5EEG)\n",
    "        patch_dim = channels * time_size * fre_size    #EEG_patch_dim:2*1*4\n",
    "\n",
    "        self.eeg_to_patch_embedding = nn.Sequential(\n",
    "            Rearrange('b c (h p3) (w p4) -> b (h w) (p3 p4 c)', p3 = fre_size, p4 = time_size),\n",
    "            nn.Linear(patch_dim, dim),       ##eeg_patch_dim(2*1*4) to embed_dim\n",
    "        )\n",
    "\n",
    "        # Learned positional embedding (one row per patch plus one for the class token).\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))\n",
    "        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))\n",
    "        self.dropout = nn.Dropout(emb_dropout)\n",
    "        self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)\n",
    "\n",
    "        self.pool = pool\n",
    "        self.dim = dim\n",
    "        self.to_latent = nn.Identity()\n",
    "\n",
    "    def forward(self, input_eeg):\n",
    "        # The sequence length and device are read from the input itself, so the module\n",
    "        # no longer depends on the notebook-level `seq_len` and `DEVICE` variables.\n",
    "        batch, n_segments = input_eeg.shape[0], input_eeg.shape[1]\n",
    "        segment_features = []\n",
    "        for seg in range(n_segments):\n",
    "            x = self.eeg_to_patch_embedding(input_eeg[:, seg])\n",
    "            b, n, _ = x.shape\n",
    "            # Broadcast the single learned class token across the batch.\n",
    "            cls_tokens = self.cls_token.expand(b, -1, -1)\n",
    "            x = torch.cat((cls_tokens, x), dim=1)\n",
    "            x = x + self.pos_embedding[:, :(n + 1)]\n",
    "            x = self.dropout(x)\n",
    "            x = self.transformer(x)\n",
    "            x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]\n",
    "            segment_features.append(self.to_latent(x))\n",
    "        if not segment_features:\n",
    "            # Zero-length sequence: keep returning an empty (batch, 0, dim) tensor.\n",
    "            return torch.empty(batch, 0, self.dim, device=input_eeg.device)\n",
    "        # (batch, dim) per segment -> (batch, segment, dim)\n",
    "        return torch.stack(segment_features, dim=1)\n",
    "\n",
    "class seq_ViT(nn.Module):\n",
    "    \"\"\"Sequence-level transformer that classifies from a sequence of segment vectors.\n",
    "\n",
    "    ``forward`` takes a (batch, tokens, dim) tensor, lifts every token to ``2 * dim``,\n",
    "    prepends a learned class token, adds positional embeddings, runs the transformer and\n",
    "    returns log-probabilities of shape (batch, num_classes).\n",
    "\n",
    "    Args:\n",
    "        num_classes: number of output classes.\n",
    "        seq_len: maximum number of input tokens (sizes the positional embedding).\n",
    "        dim: width of the incoming token vectors.\n",
    "        dim_head, depth, heads, mlp_dim, dropout: forwarded to ``Transformer``.\n",
    "        pool: 'cls' to classify from the class token, 'mean' to average all tokens.\n",
    "        emb_dropout: dropout applied to the embedded tokens.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,num_classes, seq_len, dim, dim_head, depth, heads, mlp_dim, dropout, pool = 'cls', emb_dropout = 0.):\n",
    "        super().__init__()\n",
    "        num_patches = seq_len\n",
    "        seq_dim = dim\n",
    "        embedding_dim = dim*2\n",
    "        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'\n",
    "\n",
    "        self.seq_to_patch_embedding = nn.Sequential(\n",
    "            nn.Linear(seq_dim, embedding_dim),\n",
    "        )\n",
    "\n",
    "        # Learned positional embedding (one row per token plus one for the class token).\n",
    "        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, embedding_dim))\n",
    "        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))\n",
    "        self.dropout = nn.Dropout(emb_dropout)\n",
    "        self.transformer = Transformer(embedding_dim, depth, heads, dim_head, mlp_dim, dropout)\n",
    "\n",
    "        self.pool = pool\n",
    "        self.dim = dim\n",
    "        self.to_latent = nn.Identity()\n",
    "\n",
    "        # NOTE(review): the trailing Dropout acts directly on the class logits when\n",
    "        # dropout > 0; it is kept so existing behaviour is unchanged.\n",
    "        self.mlp_head = nn.Sequential(\n",
    "            nn.LayerNorm(embedding_dim),\n",
    "            nn.Linear(embedding_dim, num_classes),\n",
    "            nn.Dropout(dropout)\n",
    "        )\n",
    "\n",
    "    def forward(self, input_seq):\n",
    "        seq = self.seq_to_patch_embedding(input_seq)\n",
    "        # `n` is the number of tokens (axis 1). It used to be bound to the embedding\n",
    "        # width, which made the positional slice below correct only by accident.\n",
    "        b, n, _ = seq.shape\n",
    "        cls_tokens = self.cls_token.expand(b, -1, -1)\n",
    "        seq = torch.cat((cls_tokens, seq), dim=1)\n",
    "        seq = seq + self.pos_embedding[:, :(n + 1)]\n",
    "        seq = self.dropout(seq)\n",
    "        seq = self.transformer(seq)\n",
    "        seq = seq.mean(dim = 1) if self.pool == 'mean' else seq[:, 0]\n",
    "        mlp_out = self.mlp_head(seq)\n",
    "        return F.log_softmax(mlp_out, dim = -1)\n",
    "    \n",
    "class Seq_VTrans(nn.Module):\n",
    "    \"\"\"Two-stage model: ``ViT`` encodes each segment, ``seq_ViT`` classifies the encoded sequence.\n",
    "\n",
    "    All constructor arguments are keyword-only and are passed on to the two sub-models.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *, seq_len,num_classes,image_size, time_size, fre_size, dim, channels , dim_head,depth, heads, mlp_dim, dropout, pool = 'cls', emb_dropout = 0.):\n",
    "        super().__init__()\n",
    "        shared = dict(dim_head=dim_head, depth=depth, heads=heads, mlp_dim=mlp_dim,\n",
    "                      dropout=dropout, pool=pool, emb_dropout=emb_dropout)\n",
    "        self.trans = ViT(image_size=image_size, time_size=time_size, fre_size=fre_size,\n",
    "                         dim=dim, channels=channels, **shared)\n",
    "        self.seq_trans = seq_ViT(num_classes=num_classes, seq_len=seq_len, dim=dim, **shared)\n",
    "\n",
    "    def forward(self,input_eeg):\n",
    "        # Segment encoder first, then the sequence-level classifier on its output.\n",
    "        return self.seq_trans(self.trans(input_eeg))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 149,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda:0\n"
     ]
    }
   ],
   "source": [
    "DEVICE = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "print(DEVICE)\n",
    "\n",
    "# Seed torch's global RNG so weight initialisation and DataLoader shuffling are repeatable.\n",
    "SEED = 66\n",
    "torch.manual_seed(SEED)\n",
    "\n",
    "Seq_VTransmodel = Seq_VTrans(\n",
    "    image_size = 30, \n",
    "    time_size = 1, \n",
    "    fre_size = 4, \n",
    "    num_classes = 5,\n",
    "    channels = 2,\n",
    "    seq_len = seq_len,   # reuse the value the data cells reshaped with, instead of a second literal 3\n",
    "    depth = 8,   \n",
    "    heads = 4, \n",
    "    dim = 32,     #32\n",
    "    dim_head = 32,    #32\n",
    "    mlp_dim = 128,     #128\n",
    "    dropout = 0.0\n",
    ").to(DEVICE)\n",
    "\n",
    "optimizer = torch.optim.AdamW(Seq_VTransmodel.parameters(), lr = 1e-6)\n",
    "# The model returns log-probabilities; CrossEntropyLoss applies log-softmax again,\n",
    "# which leaves already-normalised log-probabilities unchanged.\n",
    "Classifier_loss = torch.nn.CrossEntropyLoss()\n",
    "\n",
    "#Seq_VTransmodel.load_state_dict(torch.load('/project/hikaku_db/ziwei/Model_VT_VT/Model_VT_VT_state_diff_shape_1'))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "EPOCH = 100\n",
    "loss_list=[]\n",
    "los_min=10**10\n",
    "val_loss_list=[]\n",
    "ac_list=[]\n",
    "\n",
    "\n",
    "for epoch in tqdm.tqdm(range(EPOCH)):\n",
    "    # ---- training phase ----\n",
    "    Seq_VTransmodel.train()\n",
    "    running_loss = 0.0\n",
    "    count = 0\n",
    "    for inputs, labels in train_dataloader:\n",
    "        inputs = inputs.to(DEVICE)\n",
    "        # The target is the label at position 1 of each sequence (the middle one for\n",
    "        # seq_len = 3). Class indices give the same loss as the one-hot form used before.\n",
    "        targets = labels[:, 1].long().to(DEVICE)\n",
    "        # Clear the gradients of the previous step; without this they accumulate\n",
    "        # across batches and every update uses the sum of all past gradients.\n",
    "        optimizer.zero_grad()\n",
    "        cls_out = Seq_VTransmodel(inputs)\n",
    "        loss = Classifier_loss(cls_out, targets)\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "        running_loss += loss.item()\n",
    "        count += 1\n",
    "    loss_loss = running_loss/count\n",
    "    loss_list.append(loss_loss)\n",
    "    print('epoch',epoch+1,':finished')\n",
    "    print('train_loss:',loss_loss)\n",
    "\n",
    "    # ---- evaluation phase on the held-out split ----\n",
    "    Seq_VTransmodel.eval()   # switch dropout layers to inference behaviour\n",
    "    with torch.no_grad():\n",
    "        count = 0\n",
    "        running_loss = 0.0\n",
    "        pre = []\n",
    "        lab = []\n",
    "        for inputs, labels in test_dataloader:\n",
    "            inputs = inputs.to(DEVICE)\n",
    "            targets = labels[:, 1].long().to(DEVICE)\n",
    "            cls_out = Seq_VTransmodel(inputs)\n",
    "            loss = Classifier_loss(cls_out, targets)\n",
    "            running_loss += loss.item()\n",
    "            count += 1\n",
    "            # extend() keeps the lists flat; sum(list_of_lists, []) was quadratic.\n",
    "            pre.extend(cls_out.argmax(dim=1).tolist())\n",
    "            lab.extend(targets.tolist())\n",
    "        loss_loss = running_loss/count\n",
    "        val_loss_list.append(loss_loss)\n",
    "        print('val_loss:',loss_loss)\n",
    "        cl = classification_report(lab, pre,output_dict=True)\n",
    "        print(cl)\n",
    "        ac_list.append(cl['accuracy'])\n",
    "        \n",
    "        #torch.save(Seq_VTransmodel.state_dict(),'/project/hikaku_db/ziwei/Model_VT_VT/Model_VT_VT_state_3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Loss curves per epoch; the first epoch is omitted, as in the original plot.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(range(2, len(loss_list) + 1), loss_list[1:], label='train loss')\n",
    "ax.plot(range(2, len(val_loss_list) + 1), val_loss_list[1:], label='held-out (test split) loss')\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('cross-entropy loss')\n",
    "ax.set_title('Loss per epoch')\n",
    "ax.legend()\n",
    "#plt.savefig('/project/hikaku_db/ziwei/Model_VT_VT/loss_plt_diff_shape_1.jpg')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Accuracy on the held-out split after each epoch.\n",
    "fig, ax = plt.subplots()\n",
    "ax.plot(range(1, len(ac_list) + 1), ac_list)\n",
    "ax.set_xlabel('epoch')\n",
    "ax.set_ylabel('accuracy')\n",
    "ax.set_title('Held-out (test split) accuracy per epoch')\n",
    "#plt.savefig('/project/hikaku_db/ziwei/Model_VT_VT/ac_plt_diff_shape_1.jpg')\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
